import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import euclidean_distances
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Dropout
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, accuracy_score
import itertools
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
import sys
import xgboost as xgb
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LinearRegression, LogisticRegression
from collections import Counter

class MultiClassFramework:

    def __init__(self, encoding_dim=10, confidence_threshold=0.6, knn_neighbors=5,
                 mislabeled_conf_threshold=0.7, low_conf_threshold=0.3,
                 ambiguous_diff_threshold=0.2, inappropriateness_outlier_weight=0.3,
                 max_removal_combinations=3, adaptive_threshold=True,
                 max_lp_iters=200, lp_epsilon=1e-6,
                 IR=None, N=None, F=None, P=None, verbose=True,
                 final_classifier_type='random_forest', 
                 xgboost_params=None, 
                 stacked_ensemble_params=None 
                ):
        """Configure the three-phase semi-supervised multi-class framework.

        Parameters
        ----------
        encoding_dim : width of the autoencoder bottleneck (latent space).
        confidence_threshold : base cut-off for accepting pseudo-labels.
        knn_neighbors : k used by the similarity-graph construction.
        mislabeled_conf_threshold : min classifier confidence for flagging a
            sample as mislabeled in phase 1.
        low_conf_threshold : below this confidence in the given label, the
            sample is flagged in phase 1.
        ambiguous_diff_threshold : max gap between the top-2 class
            probabilities for a prediction to count as ambiguous.
        inappropriateness_outlier_weight : weight of the outlier-ratio term in
            the phase-3 class inappropriateness score.
        max_removal_combinations : cap on the combination size when evaluating
            class removal in phase 3.
        adaptive_threshold : adapt the pseudo-label threshold to the observed
            confidence distribution (see _adaptive_confidence_threshold).
        max_lp_iters, lp_epsilon : label-propagation iteration cap and
            relative-change convergence tolerance.
        IR, N, F, P : dataset statistics (imbalance ratio, sample count,
            feature count; P is stored but not used in the visible code —
            TODO confirm its intended role) driving the dynamic
            hyper-parameters computed below.
        verbose : enable progress logging through _log.
        final_classifier_type : 'random_forest' | 'xgboost' | 'stacked_ensemble'.
        xgboost_params, stacked_ensemble_params : optional per-classifier
            overrides (defaulted to {} to avoid shared mutable defaults).
        """

        self.scaler = StandardScaler()
        self.encoding_dim = encoding_dim
        self.confidence_threshold = confidence_threshold
        self.knn_neighbors = knn_neighbors

        self.mislabeled_conf_threshold = mislabeled_conf_threshold
        self.low_conf_threshold = low_conf_threshold
        self.ambiguous_diff_threshold = ambiguous_diff_threshold
        self.inappropriateness_outlier_weight = inappropriateness_outlier_weight
        self.max_removal_combinations = max_removal_combinations
        self.adaptive_threshold = adaptive_threshold
        self.verbose = verbose

        self.max_lp_iters = max_lp_iters
        self.lp_epsilon = lp_epsilon

        # Models are created lazily in phase 1 / phase 2.
        self.autoencoder = None
        self.encoder = None
        self.final_classifier = None

        # Per-sample record of why labeled instances were demoted in phase 1.
        self.original_labeled_log = {}
        self.all_original_classes = None
        self.n_classes = None

        # Bidirectional maps between original labels and internal 0..n-1 ids.
        self._original_to_internal_label_map = {}
        self._internal_to_original_label_map = {}

        # Dataset statistics must be set before the dynamic calculations below.
        self.IR = IR
        self.N = N
        self.F = F
        self.P = P

        self.outlier_reconstruction_iqr_multiplier_dynamic = self._calculate_outlier_reconstruction_iqr_multiplier()
        self.inappropriateness_f1_weight_dynamic = self._calculate_inappropriateness_f1_weight()
        self.alpha_lp_dynamic = None # Will be calculated in phase2 as it depends on n_classes

        self.final_classifier_type = final_classifier_type
        self.xgboost_params = xgboost_params if xgboost_params is not None else {}
        self.stacked_ensemble_params = stacked_ensemble_params if stacked_ensemble_params is not None else {}


    def _log(self, message):
        
        if self.verbose:
            print(message)

    def _determine_label_mapping(self, y_data):
        
        actual_labels = [label for label in y_data if str(label).strip() != 'unlabeled'] 
        original_labels = np.unique(actual_labels)

        self.all_original_classes = original_labels
        self.n_classes = len(original_labels)

        sorted_original_labels = sorted(original_labels.tolist())
        self._original_to_internal_label_map = {
            orig_label: i for i, orig_label in enumerate(sorted_original_labels)
        }
        self._internal_to_original_label_map = {
            i: orig_label for i, orig_label in enumerate(sorted_original_labels)
        }
        self._log(f"  Detected {self.n_classes} classes: {sorted_original_labels}")
        self._log(f"  Internal label mapping: {self._original_to_internal_label_map}")

    def _convert_to_internal_labels(self, y_original):
        
        return np.array([self._original_to_internal_label_map.get(str(label).strip(), -1) if str(label).strip() != 'unlabeled' else -1 for label in y_original])

    def _convert_to_original_labels(self, y_internal):
        return np.array([self._internal_to_original_label_map.get(label, 'unlabeled') if label != -1 else 'unlabeled' for label in y_internal])

    def _build_autoencoder(self, input_dim):
        """Create a symmetric dense autoencoder plus its encoder sub-model.

        Layer widths shrink 2*d -> d -> encoding_dim and mirror back out,
        capped at 64/32 units; dropout 0.2 regularizes the encoder side.
        Stores the compiled models in self.autoencoder / self.encoder.
        """
        wide = min(64, input_dim * 2)
        narrow = min(32, input_dim)

        inp = Input(shape=(input_dim,))
        h = Dense(wide, activation='relu')(inp)
        h = Dropout(0.2)(h)
        h = Dense(narrow, activation='relu')(h)
        h = Dropout(0.2)(h)
        bottleneck = Dense(self.encoding_dim, activation='relu')(h)

        d = Dense(narrow, activation='relu')(bottleneck)
        d = Dense(wide, activation='relu')(d)
        out = Dense(input_dim, activation='linear')(d)

        self.autoencoder = Model(inp, out)
        self.encoder = Model(inp, bottleneck)
        self.autoencoder.compile(optimizer='adam', loss='mse')

    def _build_improved_similarity_matrix(self, X_latent_combined, method='adaptive_rbf'):
        """Build a row-normalized affinity matrix over the latent samples.

        Methods:
          * 'adaptive_rbf' — RBF kernel with per-pair bandwidth
            sigma_ij = sqrt(b_i * b_j), where b_i is sample i's distance to its
            k-th nearest neighbor (floored at 1e-9); diagonal zeroed.
          * 'knn_graph'    — symmetric kNN graph weighted by 1/(1+distance).
          * anything else  — global-bandwidth RBF (diagonal stays 1, as before).

        Returns an (n, n) matrix whose rows sum to 1 (all-zero rows stay zero).
        """
        n_samples = X_latent_combined.shape[0]
        if n_samples <= 1:
            return np.zeros((n_samples, n_samples))

        knn_k = min(self.knn_neighbors, n_samples - 1)
        if knn_k < 1:
            return np.zeros((n_samples, n_samples))

        if method == 'adaptive_rbf':
            nn = NearestNeighbors(n_neighbors=knn_k)
            nn.fit(X_latent_combined)
            distances, _ = nn.kneighbors(X_latent_combined)
            # Per-sample bandwidth = distance to the k-th neighbor, floored so
            # sigma is always strictly positive.
            bandwidths = np.maximum(distances[:, -1], 1e-9)

            dist_matrix = euclidean_distances(X_latent_combined)

            # Vectorized form of the former O(n^2) Python double loop:
            # W[i,j] = exp(-d_ij^2 / (2 * b_i * b_j)); identical values, since
            # sigma_ij^2 = b_i * b_j and the floor guarantees sigma > 0.
            sigma_sq = np.outer(bandwidths, bandwidths)
            W = np.exp(-dist_matrix**2 / (2 * sigma_sq))
            np.fill_diagonal(W, 0.0)

        elif method == 'knn_graph':
            nn = NearestNeighbors(n_neighbors=knn_k)
            nn.fit(X_latent_combined)
            distances, indices = nn.kneighbors(X_latent_combined)

            W = np.zeros((n_samples, n_samples))
            for i in range(n_samples):
                for j_idx, j in enumerate(indices[i]):
                    if i != j:
                        W[i, j] = 1.0 / (1.0 + distances[i, j_idx])
                        W[j, i] = W[i, j] # Ensure symmetry

        else:
            # Global-bandwidth RBF fallback. The distance matrix is computed
            # once (it was previously computed twice). Diagonal is exp(0)=1,
            # matching the original behavior of this branch.
            dist_matrix = euclidean_distances(X_latent_combined)
            sigma = np.std(dist_matrix)
            if sigma == 0:
                sigma = 1e-6
            W = np.exp(-dist_matrix**2 / (2 * sigma**2))

        # Row-normalize so each row is a probability distribution; rows that
        # sum to zero are left as zeros instead of dividing by zero.
        row_sums = W.sum(axis=1, keepdims=True)
        W = np.divide(W, row_sums, out=np.zeros_like(W), where=row_sums!=0)

        return W

    def _adaptive_confidence_threshold(self, F_unlabeled, initial_threshold):
        
        if not self.adaptive_threshold or F_unlabeled.shape[0] == 0:
            return initial_threshold

        max_probs = np.max(F_unlabeled, axis=1)

        p75 = np.percentile(max_probs, 75)
        p85 = np.percentile(max_probs, 85)
        p90 = np.percentile(max_probs, 90)

        min_samples_for_higher_threshold = max(3, int(F_unlabeled.shape[0] * 0.05))

        if np.sum(max_probs >= p90) >= min_samples_for_higher_threshold:
            return p90
        elif np.sum(max_probs >= p85) >= min_samples_for_higher_threshold:
            return p85
        elif np.sum(max_probs >= p75) >= min_samples_for_higher_threshold:
            return p75
        else:
            return max(initial_threshold, np.percentile(max_probs, 60))

    def _calculate_outlier_reconstruction_iqr_multiplier(self):

        if self.IR is None or self.F is None or self.N is None or self.N == 0:
            self._log("  Warning: IR, F, or N not set for dynamic outlier_reconstruction_iqr_multiplier. Using default 1.5.")
            return 1.5

        log2_1_plus_ir = np.log2(max(1e-9, 1 + self.IR))
        sqrt_f_div_n = np.sqrt(max(1e-9, self.F / self.N))

        return 1 + log2_1_plus_ir * sqrt_f_div_n

    def _calculate_inappropriateness_f1_weight(self):

        if self.IR is None:
            self._log("  Warning: IR not set for dynamic inappropriateness_f1_weight. Using default 0.7.")
            return 0.7

        denominator = 1 + np.log2(max(1e-9, self.IR + 1))
        return 1 / denominator

    def _calculate_alpha_lp(self):

        if self.F is None or self.n_classes is None or self.IR is None:
            return 0.8

        log_1_plus_f = np.log(max(1e-9, 1 + self.F))
        log_1_plus_c = np.log(max(1e-9, 1 + self.n_classes))
        log_1_plus_r = np.log(max(1e-9, 1 + self.IR))

        denominator = log_1_plus_c + log_1_plus_r
        if denominator == 0: 
            return 0.8

        return log_1_plus_f / denominator

    def _calculate_f1_weight_for_confidence(self):

        if self.F is None:
            self._log("  Warning: F not set for dynamic f1_weight_for_confidence. Using default 0.7.")
            return 0.7

        denominator = 1 + np.log(max(1e-9, 1 + self.F))
        return 1 / denominator

    def _calculate_inappropriateness_threshold(self):

        if self.IR is None or self.F is None or self.n_classes is None:
            self._log("  Warning: IR, F, or n_classes not set for dynamic inappropriateness_score_threshold.")
            return 0.2

        log_1_plus_r = np.log(max(1e-9, 1 + self.IR))
        log_denominator = np.log(max(1e-9, (1 + self.F) * (1 + self.n_classes) * (1 + self.IR)))

        if log_denominator == 0:
            return 0.2 # Fallback if denominator is zero

        return 0.2 * log_1_plus_r / log_denominator


    def phase1_data_cleaning(self, X_labeled_train, y_labeled_train, X_unlabeled_train):
        """Phase 1: flag suspicious labeled samples and demote them to the unlabeled pool.

        Steps:
          1. Build the label mapping and scale all data (scaler fit on labeled data).
          2. Train an autoencoder on labeled + unlabeled data; labeled samples with
             an IQR-outlying reconstruction error are flagged.
          3. Train a RandomForest on the labeled data; flag samples that are
             confidently mislabeled, have very low confidence in their own label,
             or whose top-two class probabilities are nearly tied (ambiguous).
          4. Move flagged samples' features into the unlabeled pool, recording the
             reason(s) per sample index in ``self.original_labeled_log``.

        Returns
        -------
        (X_labeled_cleaned, y_labeled_cleaned, X_unlabeled_expanded,
         X_latent_labeled_cleaned, X_latent_unlabeled_expanded)
        """
        self._log("\n Phase 1 ")

        self._determine_label_mapping(y_labeled_train)
        y_labeled_train_internal = self._convert_to_internal_labels(y_labeled_train)

        # Scaler is fit on labeled data only; unlabeled data reuses it.
        X_labeled_train_scaled = self.scaler.fit_transform(X_labeled_train)
        X_unlabeled_train_scaled = self.scaler.transform(X_unlabeled_train)
        X_full_train_scaled = np.vstack([X_labeled_train_scaled, X_unlabeled_train_scaled])

        # The autoencoder is trained on the full (labeled + unlabeled) pool.
        self._build_autoencoder(X_full_train_scaled.shape[1])
        early_stopping_callback = EarlyStopping(monitor='val_loss', patience=10,
                                               restore_best_weights=True, verbose=0)
        self.autoencoder.fit(X_full_train_scaled, X_full_train_scaled,
                             epochs=200, batch_size=32, verbose=0, validation_split=0.1,
                             callbacks=[early_stopping_callback])

        # (Two unused encoder.predict calls were removed here: their results
        # were never read; latents for the cleaned split are computed below.)

        X_labeled_reconstructed = self.autoencoder.predict(X_labeled_train_scaled, verbose=0)
        reconstruction_errors = np.mean(np.square(X_labeled_train_scaled - X_labeled_reconstructed), axis=1)

        # Tukey-style upper fence on reconstruction error with a dynamic multiplier.
        q75, q25 = np.percentile(reconstruction_errors, [75, 25])
        iqr = q75 - q25

        error_threshold = q75 + self.outlier_reconstruction_iqr_multiplier_dynamic * iqr
        self._log(f"  Threshold: {error_threshold:.4f} "
                  f"(using multiplier {self.outlier_reconstruction_iqr_multiplier_dynamic:.3f})")

        # NOTE(review): the forest is fit and evaluated on the same samples
        # (in-sample confidences) — confirm this is intended rather than CV.
        temp_classifier = RandomForestClassifier(
            n_estimators=100,
            max_depth=None,
            min_samples_split=2,
            class_weight='balanced',
            random_state=42
        )
        temp_classifier.fit(X_labeled_train_scaled, y_labeled_train_internal)
        predictions_internal = temp_classifier.predict(X_labeled_train_scaled)
        probabilities = temp_classifier.predict_proba(X_labeled_train_scaled)

        outlier_flags = np.zeros(X_labeled_train.shape[0], dtype=bool)
        self.original_labeled_log = {}

        for i in range(X_labeled_train.shape[0]):
            true_label_internal = y_labeled_train_internal[i]

            # Skip unmapped (-1) or out-of-range labels. Column indexing below
            # assumes predict_proba columns align with internal ids 0..n-1,
            # which holds when every class appears in the training labels —
            # TODO confirm for partially-present class sets.
            if not (0 <= true_label_internal < probabilities.shape[1]):
                continue

            predicted_label_internal = predictions_internal[i]
            conf_true_label = probabilities[i, true_label_internal]
            max_conf = np.max(probabilities[i])

            is_mislabeled = (predicted_label_internal != true_label_internal) and (max_conf > self.mislabeled_conf_threshold)
            is_very_low_confidence = (conf_true_label < self.low_conf_threshold)
            is_ae_outlier = (reconstruction_errors[i] > error_threshold)

            is_ambiguous = False
            if probabilities.shape[1] > 1:
                # Ambiguous when the top two class probabilities are nearly tied.
                sorted_probs = np.sort(probabilities[i])[::-1]
                second_max_conf = sorted_probs[1]
                is_ambiguous = (max_conf - second_max_conf < self.ambiguous_diff_threshold)

            if is_mislabeled or is_very_low_confidence or is_ae_outlier or is_ambiguous:
                outlier_flags[i] = True
                self.original_labeled_log[i] = {
                    'original_label': self._internal_to_original_label_map[true_label_internal],
                    'reason': []
                }
                if is_mislabeled: self.original_labeled_log[i]['reason'].append('mislabeled')
                if is_very_low_confidence: self.original_labeled_log[i]['reason'].append('very_low_confidence')
                if is_ae_outlier: self.original_labeled_log[i]['reason'].append('ae_outlier')
                if is_ambiguous: self.original_labeled_log[i]['reason'].append('ambiguous_prediction')

        X_labeled_cleaned = X_labeled_train[~outlier_flags]
        y_labeled_cleaned = y_labeled_train[~outlier_flags]

        # Flagged samples join the unlabeled pool (features only, labels dropped).
        X_unlabeled_expanded = np.vstack([X_unlabeled_train, X_labeled_train[outlier_flags]])

        X_latent_labeled_cleaned = self.encoder.predict(self.scaler.transform(X_labeled_cleaned), verbose=0)
        X_latent_unlabeled_expanded = self.encoder.predict(self.scaler.transform(X_unlabeled_expanded), verbose=0)

        self._log(f"  Moved {np.sum(outlier_flags)} instances from labeled to unlabeled pool.")
        self._log(f"  Labeled data remaining: {X_labeled_cleaned.shape[0]} samples.")
        self._log(f"  Unlabeled data (expanded): {X_unlabeled_expanded.shape[0]} samples.")

        return (X_labeled_cleaned, y_labeled_cleaned, X_unlabeled_expanded,
                X_latent_labeled_cleaned, X_latent_unlabeled_expanded)

    def phase2_semi_supervised_training(self, X_labeled_cleaned, y_labeled_cleaned, X_unlabeled_expanded,
                                         X_latent_labeled_cleaned, X_latent_unlabeled_expanded):
        """Phase 2: label propagation over the latent graph, pseudo-labeling, and final training.

        Runs momentum-accelerated label propagation (LP) on the combined latent
        representation (labeled rows first, then unlabeled), keeps pseudo-labels
        whose confidence exceeds an adaptive threshold, merges them with the
        cleaned labeled set, and trains the configured final classifier
        ('random_forest', 'xgboost' or 'stacked_ensemble').

        Returns the converged LP score matrix F, or None when training is not
        possible (no samples, or fewer than 2 classes in the combined set).
        """

        self._log("\n Phase 2")

        y_labeled_cleaned_internal = self._convert_to_internal_labels(y_labeled_cleaned)

        # Labeled latents come first; their row indices [0, n_labeled) are the
        # clamped rows during propagation.
        X_latent_combined = np.vstack([X_latent_labeled_cleaned, X_latent_unlabeled_expanded])
        n_labeled = X_latent_labeled_cleaned.shape[0]
        n_total_samples = X_latent_combined.shape[0]

        if n_total_samples == 0 or n_labeled == 0:
            self._log("  Insufficient samples for SSL training.")
            return None

        self.alpha_lp_dynamic = self._calculate_alpha_lp()
        self._log(f"  Using dynamic alpha_lp: {self.alpha_lp_dynamic:.4f}")

        # F: soft label distribution per sample, initialized uniform.
        F = np.ones((n_total_samples, self.n_classes)) / self.n_classes

        # Y_fixed: one-hot targets for the labeled rows (zeros elsewhere).
        Y_fixed = np.zeros((n_total_samples, self.n_classes))

        for i, label_internal in enumerate(y_labeled_cleaned_internal):
            if 0 <= label_internal < self.n_classes:
                Y_fixed[i, label_internal] = 1.0
                F[i, :] = 0.0 
                F[i, label_internal] = 1.0 
            else:
                self._log(f"  Labeled sample {i} has an unknown internal label {label_internal}")

        # Row-stochastic affinity matrix over all latent samples.
        W = self._build_improved_similarity_matrix(X_latent_combined, method='adaptive_rbf')

        F_prev_prev = F.copy()
        # Heavy-ball momentum on the LP iteration to speed convergence.
        momentum = 0.1 

        for iteration in range(self.max_lp_iters):
            F_prev = F.copy()
            # Standard LP step: diffuse neighbors' labels, pull toward Y_fixed.
            F_new = self.alpha_lp_dynamic * (W @ F) + (1 - self.alpha_lp_dynamic) * Y_fixed

            if iteration > 0:
                F_new = F_new + momentum * (F_prev - F_prev_prev)

            # Clamp labeled rows back to their one-hot targets each iteration.
            F_new[:n_labeled] = Y_fixed[:n_labeled]

            # Re-normalize rows to probability distributions (zero rows stay zero).
            row_sums_F = F_new.sum(axis=1, keepdims=True)
            F = np.divide(F_new, row_sums_F, out=np.zeros_like(F_new), where=row_sums_F!=0)

            # Relative Frobenius-norm change as the convergence criterion.
            error = np.linalg.norm(F - F_prev, ord='fro') / (np.linalg.norm(F_prev, ord='fro') + 1e-9)

            if error < self.lp_epsilon:
                self._log(f"  DLP converged at iteration {iteration+1}")
                break

            F_prev_prev = F_prev.copy()
        else:
            # for/else: loop exhausted max_lp_iters without converging.
            self._log(f"  DLP finished after {iteration+1} iterations (max iterations reached).")

        # Pseudo-label selection on the unlabeled rows only.
        predicted_probs_unlabeled = F[n_labeled:]
        adaptive_threshold = self._adaptive_confidence_threshold(predicted_probs_unlabeled, self.confidence_threshold)
        self._log(f"  Initial confidence threshold: {self.confidence_threshold:.3f}, Adaptive threshold: {adaptive_threshold:.3f}")

        pseudo_labels_unlabeled_internal = np.argmax(predicted_probs_unlabeled, axis=1)
        max_probs_unlabeled = np.max(predicted_probs_unlabeled, axis=1)
        high_confidence_mask = max_probs_unlabeled >= adaptive_threshold

        self._log(f"  Samples above adaptive threshold: {np.sum(high_confidence_mask)}")

        # Empty placeholders keep the vstack/hstack below uniform.
        X_pseudo_labeled_scaled = np.empty((0, X_labeled_cleaned.shape[1]))
        y_pseudo_labeled_original = np.array([])

        if np.sum(high_confidence_mask) > 0:
            X_pseudo_labeled_raw = X_unlabeled_expanded[high_confidence_mask]
            X_pseudo_labeled_scaled = self.scaler.transform(X_pseudo_labeled_raw)
            y_pseudo_labeled_internal = pseudo_labels_unlabeled_internal[high_confidence_mask]
            y_pseudo_labeled_original = self._convert_to_original_labels(y_pseudo_labeled_internal)
            self._log(f"  Selected {len(y_pseudo_labeled_original)} high-confidence pseudo-labels.")

            unique_pseudo, counts_pseudo = np.unique(y_pseudo_labeled_original, return_counts=True)
            self._log(f"  Pseudo-label distribution: {dict(zip(unique_pseudo, counts_pseudo))}")
        else:
            self._log("  No high-confidence pseudo-labels found.")

        # Final training set = cleaned labeled data + accepted pseudo-labels.
        X_combined_train_scaled = self.scaler.transform(X_labeled_cleaned)
        y_combined_train_original = y_labeled_cleaned

        if X_pseudo_labeled_scaled.shape[0] > 0:
            X_combined_train_scaled = np.vstack([X_combined_train_scaled, X_pseudo_labeled_scaled])
            y_combined_train_original = np.hstack([y_combined_train_original, y_pseudo_labeled_original])

        unique_labels_in_combined = np.unique(y_combined_train_original)
        if len(unique_labels_in_combined) < 2:
            self._log("  Insufficient unique classes (less than 2) for training final classifier.")
            self.final_classifier = None
            return None

        # Internal ids are only needed by the XGBoost branch below.
        y_combined_train_internal = self._convert_to_internal_labels(y_combined_train_original)

        # Balanced class weights keyed by ORIGINAL label values.
        class_weights_dict = compute_class_weight(
            class_weight='balanced',
            classes=unique_labels_in_combined,
            y=y_combined_train_original
        )
        classifier_class_weight_map = {cls: weight for cls, weight in zip(unique_labels_in_combined, class_weights_dict)}

        if self.final_classifier_type == 'random_forest':
            self._log("  Random Forest Classifier:")
            self.final_classifier = RandomForestClassifier(
                n_estimators=200,
                max_depth=None,
                min_samples_split=2,
                class_weight=classifier_class_weight_map,
                random_state=42
            )
        elif self.final_classifier_type == 'xgboost':
            self._log("  XGBoost Classifier:")

            # XGBoost needs contiguous 0..k-1 class indices over the classes
            # actually present, so re-index the internal labels.
            unique_internal_labels_present = np.unique(y_combined_train_internal)

            temp_xgb_label_map = {label: i for i, label in enumerate(sorted(unique_internal_labels_present))}

            y_xgb_train_indexed = np.array([temp_xgb_label_map[label] for label in y_combined_train_internal])

            # Per-sample weights: look up the balanced weight via the original
            # label (the weight map is keyed by original labels).
            temp_sample_weights = np.array([
                classifier_class_weight_map[self._internal_to_original_label_map[label]]
                for label in y_combined_train_internal
            ])


            # NOTE(review): 'multi:softmax' makes predict() return hard classes;
            # if calibrated probabilities are needed downstream, 'multi:softprob'
            # may be required — confirm against phase-3 usage.
            default_xgb_params = {
                'objective': 'multi:softmax', 
                'num_class': len(unique_internal_labels_present), 
                'eval_metric': 'mlogloss', 
                'seed': 42
            }
            xgb_final_params = {**default_xgb_params, **self.xgboost_params}

            self.final_classifier = xgb.XGBClassifier(
                **xgb_final_params
            )
            # Ad-hoc attributes stashed on the classifier so phase 3 can decode
            # the re-indexed predictions back to internal labels.
            self.final_classifier._temp_xgb_label_map = temp_xgb_label_map
            self.final_classifier._reverse_temp_xgb_label_map = {v: k for k, v in temp_xgb_label_map.items()}


        elif self.final_classifier_type == 'stacked_ensemble':
            self._log(" Stacked Ensemble:")
            # Base models: kNN, C4.5-style tree (entropy), polynomial SVC
            # (probability=True so stacking can use predict_proba), naive Bayes.
            estimators = [
                ('knn', KNeighborsClassifier(n_neighbors=self.stacked_ensemble_params.get('knn_k', 3))),
                ('c45', DecisionTreeClassifier(
                    criterion='entropy',
                    max_depth=self.stacked_ensemble_params.get('c45_max_depth', None), 
                    min_samples_leaf=self.stacked_ensemble_params.get('c45_min_samples_leaf', 2),
                    random_state=42
                )),
                ('svc', SVC(
                    C=self.stacked_ensemble_params.get('svc_C', 1.0),
                    tol=self.stacked_ensemble_params.get('svc_tol', 0.001),
                    degree=self.stacked_ensemble_params.get('svc_degree', 1),
                    kernel='poly',
                    probability=True, 
                    random_state=42
                )),
                ('naive_bayes', GaussianNB())
            ]

            meta_classifier = LogisticRegression(
                solver='liblinear', 
                # multi_class='auto', 
                random_state=42,
                class_weight='balanced' 
            )

            self.final_classifier = StackingClassifier(
                estimators=estimators,
                final_estimator=meta_classifier,
                cv=self.stacked_ensemble_params.get('cv', 5), 
                passthrough=self.stacked_ensemble_params.get('passthrough', False),
                n_jobs=-1 
            )

        else:
            raise ValueError(f"Unknown classifier type: {self.final_classifier_type}")

        # XGBoost is fit on re-indexed labels with per-sample weights; the
        # other classifiers are fit directly on the original label values.
        if self.final_classifier_type == 'xgboost':
            self.final_classifier.fit(X_combined_train_scaled, y_xgb_train_indexed, sample_weight=temp_sample_weights)
        else:
            self.final_classifier.fit(X_combined_train_scaled, y_combined_train_original)

        self._log(f"  Classifier ({self.final_classifier_type}) trained successfully.")
        return F

    def phase3_class_appropriateness_assessment(self, X_test, y_test,
                                                X_labeled_cleaned_original, y_labeled_cleaned_original,
                                                X_unlabeled_expanded_original, F_lp_matrix):

        self._log("\n Phase 3")

        inappropriateness_score_threshold = self._calculate_inappropriateness_threshold()
        self._log(f"  Dynamic inappropriateness score threshold: {inappropriateness_score_threshold:.4f}")

        if self.final_classifier == None:
            self._log("  Cannot perform assessment")
            return

        X_test_scaled = self.scaler.transform(X_test)

        if self.final_classifier_type == 'xgboost':
            y_pred_baseline_internal_xgb_indexed = self.final_classifier.predict(X_test_scaled)
            y_pred_baseline_internal = np.array([
                self.final_classifier._reverse_temp_xgb_label_map.get(idx, -1) 
                for idx in y_pred_baseline_internal_xgb_indexed
            ])
            y_pred_baseline = self._convert_to_original_labels(y_pred_baseline_internal)
        else:
            y_pred_baseline = self.final_classifier.predict(X_test_scaled)

        common_classes_test = np.array([c for c in np.unique(y_test) if c in self.all_original_classes])
        test_mask_common = np.isin(y_test, common_classes_test)

        y_test_filtered = y_test[test_mask_common]
        y_pred_baseline_filtered = y_pred_baseline[test_mask_common]

        baseline_f1_weighted = 0.0 
        baseline_accuracy = 0.0
        baseline_precision_weighted = 0.0
        baseline_recall_weighted = 0.0
        baseline_report = {}

        if len(np.unique(y_test_filtered)) >= 2:
            baseline_report = classification_report(
                y_test_filtered, y_pred_baseline_filtered,
                labels=sorted(self.all_original_classes.tolist()), 
                output_dict=True,
                zero_division=0
            )
            baseline_accuracy = accuracy_score(y_test_filtered, y_pred_baseline_filtered)
            if 'weighted avg' in baseline_report: 
                baseline_f1_weighted = baseline_report['weighted avg']['f1-score'] # Changed from macro
                baseline_precision_weighted = baseline_report['weighted avg']['precision']
                baseline_recall_weighted = baseline_report['weighted avg']['recall']
        else:
            self._log("  Not enough common classes in filtered test set. ")

        self._log(f"  Baseline Performance (on common classes): Weighted F1={baseline_f1_weighted:.4f}, Accuracy={baseline_accuracy:.4f}") 

        self._log("\n  Detailed Analysis:")
        class_appropriateness = {}
        for cls_original in sorted(self.all_original_classes.tolist()):
            cls_str = str(cls_original) 
            f1_score_cls = baseline_report.get(cls_str, {}).get('f1-score', 0.0)
            precision_cls = baseline_report.get(cls_str, {}).get('precision', 0.0)
            recall_cls = baseline_report.get(cls_str, {}).get('recall', 0.0)

            outlier_count = sum(1 for v in self.original_labeled_log.values() if v['original_label'] == cls_original)

            original_count = np.sum(y_labeled_cleaned_original == cls_original) + outlier_count 

            outlier_ratio = outlier_count / original_count if original_count > 0 else 0

            inappropriateness_score = (1 - f1_score_cls) * self.inappropriateness_f1_weight_dynamic + \
                                      outlier_ratio * self.inappropriateness_outlier_weight
            class_appropriateness[cls_original] = {
                'f1_score_baseline': f1_score_cls,
                'precision_baseline': precision_cls,
                'recall_baseline': recall_cls,
                'outlier_ratio': outlier_ratio,
                'inappropriateness_score': inappropriateness_score
            }

        sorted_inappropriate = sorted(
            class_appropriateness.items(),
            key=lambda x: x[1]['inappropriateness_score'],
            reverse=True
        )

        self._log("  Classes ranked by inappropriateness:")
        for cls, metrics in sorted_inappropriate:
            self._log(f"    Class {cls}: F1={metrics['f1_score_baseline']:.4f}, "
                      f"Outlier Ratio={metrics['outlier_ratio']:.4f}, "
                      f"Score={metrics['inappropriateness_score']:.4f}")

        candidate_classes = [
            cls for cls, metrics in sorted_inappropriate
            if metrics['inappropriateness_score'] > inappropriateness_score_threshold
        ]
        self._log(f"\n  Candidate classes for removal (score > {inappropriateness_score_threshold:.4f}): {sorted([c for c in candidate_classes])}")

        evaluation_results = []
        n_labeled_cleaned = X_labeled_cleaned_original.shape[0]
        F_unlabeled_part_baseline = F_lp_matrix[n_labeled_cleaned:]
        pseudo_labels_unlabeled_all_internal_baseline = np.argmax(F_unlabeled_part_baseline, axis=1)
        max_probs_unlabeled_all_baseline = np.max(F_unlabeled_part_baseline, axis=1)
        adaptive_threshold_baseline = self._adaptive_confidence_threshold(F_unlabeled_part_baseline, self.confidence_threshold)
        pseudo_label_mask_baseline = (max_probs_unlabeled_all_baseline >= adaptive_threshold_baseline) & \
                                     np.isin(self._convert_to_original_labels(pseudo_labels_unlabeled_all_internal_baseline), self.all_original_classes)
        avg_pseudo_conf_baseline = np.mean(max_probs_unlabeled_all_baseline[pseudo_label_mask_baseline]) if np.sum(pseudo_label_mask_baseline) > 0 else 0.0

        f1_weight_for_confidence_dynamic = self._calculate_f1_weight_for_confidence()
        pseudo_conf_weight_for_confidence = 1 - f1_weight_for_confidence_dynamic

        combined_confidence_baseline = (baseline_f1_weighted * f1_weight_for_confidence_dynamic) + \
                                       (avg_pseudo_conf_baseline * pseudo_conf_weight_for_confidence) # Changed from macro

        evaluation_results.append({
            'classes_kept': tuple(sorted(self.all_original_classes.tolist())),
            'f1_weighted': baseline_f1_weighted, # Changed from macro_f1
            'accuracy': baseline_accuracy,
            'precision_weighted': baseline_precision_weighted,
            'recall_weighted': baseline_recall_weighted,
            'confidence': combined_confidence_baseline,
            'removed_classes': (),
            'report': baseline_report
        })

        if not candidate_classes:
            self._log("\n  No candidate classes for removal based on threshold.")
        else:
            self._log("\n  Evaluating class removal combinations:")
            max_r = min(len(candidate_classes), self.max_removal_combinations)
            if max_r < len(candidate_classes):
                self._log(f" Limiting evaluation to combinations of up to {max_r} classes.")

            n_labeled_cleaned = X_labeled_cleaned_original.shape[0]

            for r in range(1, max_r + 1):
                self._log(f"  Evaluating combinations of {r} classes for removal.")
                for removal_combo in itertools.combinations(candidate_classes, r):
                    kept_classes = [c for c in self.all_original_classes if c not in removal_combo]
                    if len(kept_classes) < 2:
                        self._log(f"    Skipping combination {removal_combo}: Less than 2 classes remaining.")
                        continue

                    train_mask_cleaned_labeled = np.isin(y_labeled_cleaned_original, kept_classes)
                    X_train_labeled_filt = X_labeled_cleaned_original[train_mask_cleaned_labeled]
                    y_train_labeled_filt = y_labeled_cleaned_original[train_mask_cleaned_labeled]

                    if len(np.unique(y_train_labeled_filt)) < 2:
                        continue

                    F_unlabeled_part = F_lp_matrix[n_labeled_cleaned:]
                    pseudo_labels_unlabeled_all_internal = np.argmax(F_unlabeled_part, axis=1)
                    max_probs_unlabeled_all = np.max(F_unlabeled_part, axis=1)
                    adaptive_threshold_subset = self._adaptive_confidence_threshold(
                        F_unlabeled_part, self.confidence_threshold
                    )
                    pseudo_label_mask = (max_probs_unlabeled_all >= adaptive_threshold_subset) & \
                                        np.isin(self._convert_to_original_labels(pseudo_labels_unlabeled_all_internal), kept_classes)

                    X_pseudo_labeled_filt_raw = X_unlabeled_expanded_original[pseudo_label_mask]
                    y_pseudo_labeled_filt = self._convert_to_original_labels(pseudo_labels_unlabeled_all_internal[pseudo_label_mask])

                    avg_pseudo_conf_current = np.mean(max_probs_unlabeled_all[pseudo_label_mask]) if np.sum(pseudo_label_mask) > 0 else 0.0

                    X_combined_train_filt_scaled = self.scaler.transform(X_train_labeled_filt)
                    y_combined_train_filt = y_train_labeled_filt

                    if X_pseudo_labeled_filt_raw.shape[0] > 0:
                        X_combined_train_filt_scaled = np.vstack([X_combined_train_filt_scaled, self.scaler.transform(X_pseudo_labeled_filt_raw)])
                        y_combined_train_filt = np.hstack([y_combined_train_filt, y_pseudo_labeled_filt])

                    if len(np.unique(y_combined_train_filt)) < 2:
                        continue

                    test_mask = np.isin(y_test, kept_classes)
                    X_test_filt = X_test[test_mask]
                    y_test_filt = y_test[test_mask]

                    if len(np.unique(y_test_filt)) < 2:
                        continue

                    unique_labels_subset = np.unique(y_combined_train_filt)
                    if len(unique_labels_subset) > 1:
                        class_weights_subset = compute_class_weight(
                            class_weight='balanced',
                            classes=unique_labels_subset,
                            y=y_combined_train_filt
                        )
                        classifier_class_weight_map_subset = {cls: weight for cls, weight in zip(unique_labels_subset, class_weights_subset)}
                    else:
                        continue

                    temp_clf = None
                    if self.final_classifier_type == 'random_forest':
                        temp_clf = RandomForestClassifier(
                            n_estimators=100,
                            class_weight=classifier_class_weight_map_subset,
                            random_state=42
                        )
                    elif self.final_classifier_type == 'xgboost':
                        y_combined_train_filt_internal = self._convert_to_internal_labels(y_combined_train_filt)

                        unique_internal_labels_present_subset = np.unique(y_combined_train_filt_internal)

                        temp_xgb_label_map_subset = {label: i for i, label in enumerate(sorted(unique_internal_labels_present_subset))}

                        y_xgb_train_indexed_subset = np.array([temp_xgb_label_map_subset[label] for label in y_combined_train_filt_internal])

                        sample_weights_subset = np.array([
                            classifier_class_weight_map_subset[self._internal_to_original_label_map[label]]
                            for label in y_combined_train_filt_internal
                        ])

                        default_xgb_params_subset = {
                            'objective': 'multi:softmax',
                            'num_class': len(unique_internal_labels_present_subset), 
                            'eval_metric': 'mlogloss',
                            'use_label_encoder': False,
                            'seed': 42
                        }
                        xgb_final_params_subset = {**default_xgb_params_subset, **self.xgboost_params}

                        temp_clf = xgb.XGBClassifier(**xgb_final_params_subset)
                        temp_clf._temp_xgb_label_map = temp_xgb_label_map_subset
                        temp_clf._reverse_temp_xgb_label_map = {v: k for k, v in temp_xgb_label_map_subset.items()}

                    elif self.final_classifier_type == 'stacked_ensemble':
                        estimators_subset = [
                            ('knn', KNeighborsClassifier(n_neighbors=self.stacked_ensemble_params.get('knn_k', 3))),
                            ('c45', DecisionTreeClassifier(
                                criterion='entropy',
                                max_depth=self.stacked_ensemble_params.get('c45_max_depth', None),
                                min_samples_leaf=self.stacked_ensemble_params.get('c45_min_samples_leaf', 2),
                                random_state=42
                            )),
                            ('svc', SVC(
                                C=self.stacked_ensemble_params.get('svc_C', 1.0),
                                tol=self.stacked_ensemble_params.get('svc_tol', 0.001),
                                degree=self.stacked_ensemble_params.get('svc_degree', 1),
                                kernel='poly',
                                probability=True,
                                random_state=42
                            )),
                            ('naive_bayes', GaussianNB())
                        ]

                        meta_classifier_subset = LogisticRegression(
                            solver='liblinear',
                            # multi_class='auto', # Removed to avoid FutureWarning
                            random_state=42,
                            class_weight='balanced'
                        )

                        temp_clf = StackingClassifier(
                            estimators=estimators_subset,
                            final_estimator=meta_classifier_subset,
                            cv=self.stacked_ensemble_params.get('cv', 5),
                            passthrough=self.stacked_ensemble_params.get('passthrough', False),
                            n_jobs=-1
                        )

                    try:
                        if self.final_classifier_type == 'xgboost':
                            temp_clf.fit(X_combined_train_filt_scaled, y_xgb_train_indexed_subset, sample_weight=sample_weights_subset)
                            y_pred_internal_xgb_indexed = temp_clf.predict(self.scaler.transform(X_test_filt))
                            y_pred_internal = np.array([
                                temp_clf._reverse_temp_xgb_label_map.get(idx, -1) 
                                for idx in y_pred_internal_xgb_indexed
                            ])
                            y_pred = self._convert_to_original_labels(y_pred_internal)
                        else:
                            temp_clf.fit(X_combined_train_filt_scaled, y_combined_train_filt)
                            y_pred = temp_clf.predict(self.scaler.transform(X_test_filt))

                        report = classification_report(y_test_filt, y_pred, output_dict=True, zero_division=0, labels=sorted(kept_classes))
                        f1_weighted = report['weighted avg']['f1-score']
                        accuracy = report['accuracy']
                        precision_weighted = report['weighted avg']['precision']
                        recall_weighted = report['weighted avg']['recall']

                        combined_confidence_current = (f1_weighted * f1_weight_for_confidence_dynamic) + \
                                                      (avg_pseudo_conf_current * pseudo_conf_weight_for_confidence)

                        evaluation_results.append({
                            'classes_kept': tuple(sorted(kept_classes)),
                            'f1_weighted': f1_weighted, 
                            'accuracy': accuracy,
                            'precision_weighted': precision_weighted,
                            'recall_weighted': recall_weighted,
                            'confidence': combined_confidence_current,
                            'removed_classes': tuple(sorted(removal_combo)),
                            'report': report
                        })
                    except Exception as e:
                        continue

        results_df = pd.DataFrame(evaluation_results).sort_values('f1_weighted', ascending=False) 
        self._log("\n  Evaluation Results (sorted by Weighted F1):") 
        if not results_df.empty:
            results_df['confidence'] = results_df['confidence'].apply(lambda x: f"{x:.2%}")
            results_df['f1_weighted'] = results_df['f1_weighted'].apply(lambda x: f"{x:.4f}") 
            results_df['accuracy'] = results_df['accuracy'].apply(lambda x: f"{x:.4f}")
            results_df['precision_weighted'] = results_df['precision_weighted'].apply(lambda x: f"{x:.4f}")
            results_df['recall_weighted'] = results_df['recall_weighted'].apply(lambda x: f"{x:.4f}")
            self._log(results_df[['classes_kept', 'removed_classes', 'f1_weighted', 'accuracy', 'precision_weighted', 'recall_weighted', 'confidence']].to_string()) 
        else:
            self._log("  No evaluation.")
            return

        best_config = results_df.iloc[0]
        if best_config['removed_classes']:
            self._log(f"\n Remove classes {best_config['removed_classes']}")
            # Recalculate raw f1_weighted for gain if needed, as it's a string in DF now
            baseline_f1_weighted_raw = evaluation_results[0]['f1_weighted'] # Get raw value from original list
            self._log(f" Expected improvement: +{float(best_config['f1_weighted']) - baseline_f1_weighted_raw:.4f} Weighted F1") # Changed log message
            self._log(f" Confidence in recommendation: {best_config['confidence']}")
        else:
            self._log("\n Keep all original classes (no significant improvement by removal).")

        self._log(f"\n Classification Report: (Kept Classes: {best_config['classes_kept']}) ---")
        report_str = "\n"
        report_dict = best_config['report']
        if report_dict:
            report_labels = sorted([k for k in report_dict.keys() if k not in ['accuracy', 'macro avg', 'weighted avg', 'samples avg']])

            report_str += f"{'':<10}{'precision':>11}{'recall':>11}{'f1-score':>11}{'support':>9}\n"
            report_str += f"{'':<10}{'-----------':>11}{'-----------':>11}{'-----------':>11}{'-------':>9}\n"

            for label in report_labels:
                metrics = report_dict[label]
                report_str += f"{str(label):<10}{metrics['precision']:>11.4f}{metrics['recall']:>11.4f}{metrics['f1-score']:>11.4f}{metrics['support']:>9.0f}\n"

            report_str += "\n" # Add a newline for separation
            if 'macro avg' in report_dict:
                macro_avg = report_dict['macro avg']
                report_str += f"{'macro avg':<10}{macro_avg['precision']:>11.4f}{macro_avg['recall']:>11.4f}{macro_avg['f1-score']:>11.4f}{macro_avg['support']:>9.0f}\n"
            if 'weighted avg' in report_dict: # Added weighted avg
                weighted_avg = report_dict['weighted avg']
                report_str += f"{'weighted avg':<10}{weighted_avg['precision']:>11.4f}{weighted_avg['recall']:>11.4f}{weighted_avg['f1-score']:>11.4f}{weighted_avg['support']:>9.0f}\n"

            if 'accuracy' in report_dict:
                report_str += f"{'accuracy':<10}{'':>11}{'':>11}{report_dict['accuracy']:>11.4f}{report_dict.get('samples avg',{}).get('support',0):>9.0f}\n"
        else:
            report_str += "  No classification report available."
        self._log(report_str)


def _run_pipeline(header, name, framework, X_labeled, y_labeled, X_unlabeled, X_test, y_test):
    """Run phases 1-3 of one configured framework and print its section.

    Args:
        header: Title line printed between the '=' separator rows.
        name: Short classifier name used in the results / skip messages.
        framework: A configured MultiClassFramework instance.
        X_labeled, y_labeled: Labeled training features / targets.
        X_unlabeled: Unlabeled training features.
        X_test, y_test: Held-out evaluation data.
    """
    print("\n" + "=" * 50)
    print(header)
    print("=" * 50)

    # Phase 1 cleans the labeled set / expands the unlabeled pool;
    # Phase 2 performs the semi-supervised (label-propagation) training.
    X_lab_clean, y_lab_clean, X_unlab_exp, X_latent_lab, X_latent_unlab = \
        framework.phase1_data_cleaning(X_labeled, y_labeled, X_unlabeled)
    F_lp_matrix = framework.phase2_semi_supervised_training(
        X_lab_clean, y_lab_clean, X_unlab_exp, X_latent_lab, X_latent_unlab
    )

    # Phase 3 requires the label-propagation matrix produced by Phase 2.
    if F_lp_matrix is not None:
        print(f"\n{name}  Results:")
        framework.phase3_class_appropriateness_assessment(
            X_test, y_test, X_lab_clean, y_lab_clean, X_unlab_exp, F_lp_matrix
        )
    else:
        print(f"Skipping Phase 3 for {name} as Phase 2 did not successfully train a classifier.")


def main():
    """Load the train/test CSVs and run the semi-supervised framework three
    times, once per final classifier (Random Forest, XGBoost, Stacked Ensemble).

    Exits with status 1 if either dataset file is missing.
    """
    try:
        train_df = pd.read_csv("/train_dataset.csv")
        test_df = pd.read_csv("/test_dataset.csv")
    except FileNotFoundError as exc:
        # Fail fast and report which file was missing instead of a bare "Error".
        print(f"Error: dataset file not found: {exc.filename}")
        sys.exit(1)

    # Dataset specific processing: last column is the label, the rest features.
    X_train_full = train_df.iloc[:, :-1].values
    y_train_full = train_df.iloc[:, -1].astype(str).str.strip().values

    X_test = test_df.iloc[:, :-1].values
    y_test = test_df.iloc[:, -1].astype(str).str.strip().values

    # Rows whose label is the literal string 'unlabeled' form the unlabeled pool.
    unlabeled_mask = (y_train_full == 'unlabeled')
    X_labeled = X_train_full[~unlabeled_mask]
    y_labeled = y_train_full[~unlabeled_mask]
    X_unlabeled = X_train_full[unlabeled_mask]

    print(f" Labeled Data Samples: {X_labeled.shape[0]}")
    print(f" Unlabeled Data Samples: {X_unlabeled.shape[0]}")
    print(f"Test Samples: {X_test.shape[0]}")
    print(f"Classes in Labeled Training Set: {np.unique(y_labeled)}")
    print(f"Classes in Test Set: {np.unique(y_test)}")

    # Dataset summary statistics passed into the framework.
    N_total_instances = X_train_full.shape[0]
    F_num_features = X_train_full.shape[1]
    P_percent_labeled = X_labeled.shape[0] / N_total_instances if N_total_instances > 0 else 0.0

    _, counts_labeled = np.unique(y_labeled, return_counts=True)
    if len(counts_labeled) > 1:
        max_count_labeled = np.max(counts_labeled)
        min_count_labeled = np.min(counts_labeled)
        # Guard against division by zero if a class count is somehow zero.
        IR_calculated = max_count_labeled / min_count_labeled if min_count_labeled > 0 else 1.0
    else:
        IR_calculated = 1.0

    print(f"Total Instances (N): {N_total_instances}")
    print(f" Number of Features (F): {F_num_features}")
    print(f" Percentage Labeled (P): {P_percent_labeled:.4f}")
    print(f" Imbalance Ratio (IR) from labeled data: {IR_calculated:.2f}")

    # Framework hyperparameters shared by all three classifier variants;
    # only final_classifier_type (and its specific params) differs below.
    common_params = dict(
        encoding_dim=10,
        confidence_threshold=0.5,
        adaptive_threshold=True,
        knn_neighbors=5,
        mislabeled_conf_threshold=0.75,
        low_conf_threshold=0.25,
        ambiguous_diff_threshold=0.15,
        inappropriateness_outlier_weight=0.2,
        max_removal_combinations=3,
        max_lp_iters=200,
        lp_epsilon=1e-6,
        IR=IR_calculated,
        N=N_total_instances,
        F=F_num_features,
        P=P_percent_labeled,
        verbose=True,
    )

    # --- Random Forest final classifier ---
    framework_rf = MultiClassFramework(
        **common_params,
        final_classifier_type='random_forest'
    )
    _run_pipeline("\n Random Forest Classifier", "Random Forest", framework_rf,
                  X_labeled, y_labeled, X_unlabeled, X_test, y_test)

    # --- XGBoost final classifier ---
    xgboost_specific_params = {
        'eval_metric': 'mlogloss',
        'n_estimators': 200,
        'learning_rate': 0.1,
        'max_depth': 5,
        'subsample': 0.7,
        'colsample_bytree': 0.7
    }
    framework_xgb = MultiClassFramework(
        **common_params,
        final_classifier_type='xgboost',
        xgboost_params=xgboost_specific_params
    )
    _run_pipeline("XGBoost Classifier", "XGBoost", framework_xgb,
                  X_labeled, y_labeled, X_unlabeled, X_test, y_test)

    # --- Stacked Ensemble final classifier ---
    stacked_ensemble_specific_params = {
        'knn_k': 3,
        'c45_max_depth': None,  # C4.5-style pruning handled implicitly via min_samples_leaf
        'c45_min_samples_leaf': 2,
        'svc_C': 1.0,
        'svc_tol': 0.001,
        'svc_epsilon': 1e-12,  # not used by SVC directly; kept for reference
        'svc_degree': 1,
        'cv': 5  # cross-validation folds for stacking
    }
    framework_stacked = MultiClassFramework(
        **common_params,
        final_classifier_type='stacked_ensemble',
        stacked_ensemble_params=stacked_ensemble_specific_params
    )
    _run_pipeline("Stacked Ensemble", "Stacked Ensemble", framework_stacked,
                  X_labeled, y_labeled, X_unlabeled, X_test, y_test)


# Script entry point: run the full three-classifier comparison pipeline only
# when executed directly (not on import).
if __name__ == "__main__":
    main()
